package com.sd.foudation.cocurrency.mysql;

import com.alibaba.druid.pool.DruidDataSource;
import org.junit.Assert;
import org.junit.Test;

import javax.sql.DataSource;
import java.sql.*;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Demonstrates how concurrent "consume one unit of stock" operations interact in MySQL, using
 * three strategies: a pessimistic row lock ({@code SELECT ... FOR UPDATE}), no locking at all
 * (lost updates are expected), and optimistic locking through a {@code version} column with retries.
 *
 * <p>Requires a MySQL database reachable at {@code jdbc:mysql://localhost/test} with a
 * {@code stock} table holding a row {@code id = 1} and the columns {@code current_value},
 * {@code visible} and {@code version}.
 *
 * Created by da on 2016-10-31.
 */
public class MysqlRaceCondition {

    /** Value the stock row is reset to before each concurrent run. */
    private static final int INITIAL_VALUE = 200;

    /** Number of concurrent worker threads; each one tries to consume exactly one unit. */
    private static final int TRY_TIMES = 10;

    /** Attempts {@link #versionfy()} makes before giving up on the optimistic update. */
    private static final int MAX_RETRY_TIMES = 10;

    /** Source of the random "work" delay in {@link #updateWithVersion(Connection)}; Random is thread-safe. */
    private static final Random RANDOM = new Random();

    /** Lazily created pool shared by the test thread and all worker threads of one test instance. */
    private DruidDataSource dataSource;

    /**
     * Returns the connection pool for this test instance, creating it on first use.
     *
     * <p>The pool is cached so every caller borrows from the same bounded pool (maxActive = 20)
     * instead of each call building, and leaking, a brand-new pool. Synchronized because worker
     * threads call it concurrently.
     */
    public synchronized DataSource getDataSource() {
        if (dataSource == null) {
            DruidDataSource pool = new DruidDataSource();
            pool.setUsername("root");
            pool.setPassword("root");

            pool.setUrl("jdbc:mysql://localhost/test");
            pool.setMaxActive(20);
            dataSource = pool;
        }
        return dataSource;
    }

    /** Closes the cached pool, if any, so its physical connections are released. */
    private synchronized void closeDataSource() {
        if (dataSource != null) {
            dataSource.close();
            dataSource = null;
        }
    }

    /**
     * Consumes one unit of stock while holding a row lock taken with {@code select ... for update}.
     *
     * <p>The whole read-modify-write runs in one transaction. On any failure the transaction is
     * rolled back. The connection's original auto-commit mode is always restored, and the
     * connection is always returned to the pool.
     *
     * @throws SQLException if no row matches {@code id = 1 and visible = 1}, or any statement fails
     */
    public void rowLock() throws SQLException {
        Connection connection = getDataSource().getConnection();
        try {
            boolean autoCommit = connection.getAutoCommit();
            connection.setAutoCommit(false);
            boolean committed = false;
            try {
                Statement statement = connection.createStatement();
                try {
                    int currentValue = readCurrentValue(statement,
                            "select * from stock where id = 1 and visible = 1 for update");
                    statement.execute("update stock set visible = 0 where id = 1");
                    // do other things while the row is locked
                    pause(100L);
                    statement.execute("update stock set current_value =" + (currentValue - 1) + " where id = 1");
                    statement.execute("update stock set visible = 1 where id = 1");
                } finally {
                    statement.close();
                }
                connection.commit();
                committed = true;
            } finally {
                if (!committed) {
                    rollbackQuietly(connection);
                }
                connection.setAutoCommit(autoCommit);
            }
        } finally {
            connection.close();
        }
    }

    /**
     * Consumes one unit of stock with a plain read followed by a write, without any locking.
     *
     * <p>Concurrent callers can all read the same value during the pause. Their writes then
     * overwrite each other, which is the lost-update race this class demonstrates.
     *
     * @throws SQLException if the row {@code id = 1} is missing or a statement fails
     */
    public void withoutRowLock() throws SQLException {
        Connection connection = getDataSource().getConnection();
        try {
            Statement statement = connection.createStatement();
            try {
                int currentValue = readCurrentValue(statement, "select * from stock where id = 1 ");
                // do other things
                pause(100L);
                statement.execute("update stock set current_value =" + (currentValue - 1) + " where id = 1");
            } finally {
                statement.close();
            }
        } finally {
            connection.close();
        }
    }

    /**
     * Runs {@link #rowLock()} on {@link #TRY_TIMES} threads and expects no lost decrements.
     */
    @Test
    public void testRowLock() throws SQLException, InterruptedException {
        try {
            int currentValue = consumeConcurrently(new StockOperation() {
                @Override
                public void run() throws SQLException {
                    rowLock();
                }
            });
            Assert.assertEquals(INITIAL_VALUE - TRY_TIMES, currentValue);
            System.out.println(">>>>>>>>>>>>after consume>>>>>>>>" + currentValue);
        } finally {
            closeDataSource();
        }
        System.out.println("test >>>>>>>>>>>>>");
    }

    /**
     * Runs {@link #withoutRowLock()} on {@link #TRY_TIMES} threads and expects some decrements to be lost.
     */
    @Test
    public void testWithoutRowLock() throws SQLException, InterruptedException {
        try {
            int currentValue = consumeConcurrently(new StockOperation() {
                @Override
                public void run() throws SQLException {
                    withoutRowLock();
                }
            });
            Assert.assertNotEquals(INITIAL_VALUE - TRY_TIMES, currentValue);
            System.out.println(">>>>>>>>>>>>after consume>>>>>>>>" + currentValue);
        } finally {
            closeDataSource();
        }
        System.out.println("test >>>>>>>>>>>>>");
    }

    /**
     * Makes one optimistic attempt to consume a unit of stock.
     *
     * <p>Reads {@code version} and {@code current_value}, then waits a random 100-199 ms to widen
     * the race window. Finally it writes back only where {@code version} still equals the value
     * that was read, bumping it by one.
     *
     * @param connection connection to use; owned (and closed) by the caller
     * @return the update count of the conditional write: 0 means another writer changed
     *         {@code version} first (presumably 1 on success, assuming {@code id} is unique)
     * @throws SQLException if the row {@code id = 1} is missing or a statement fails
     */
    public int updateWithVersion(Connection connection) throws SQLException, IllegalStateException {
        Statement statement = connection.createStatement();
        try {
            int version;
            int currentValue;
            ResultSet resultSet = statement.executeQuery("select * from stock where id = 1");
            try {
                if (!resultSet.next()) {
                    throw new SQLException("No stock row with id = 1");
                }
                version = resultSet.getInt("version");
                currentValue = resultSet.getInt("current_value");
            } finally {
                resultSet.close();
            }

            // do other things
            pause(RANDOM.nextInt(100) + 100L);

            String sql = String.format("update stock set current_value = %d,version = %d where id = 1 and version = %d",
                    currentValue - 1, version + 1, version);
            return statement.executeUpdate(sql);
        } finally {
            statement.close();
        }
    }

    /**
     * Consumes one unit of stock using version-based optimistic locking.
     *
     * <p>Retries {@link #updateWithVersion(Connection)} up to {@link #MAX_RETRY_TIMES} times. It
     * reports each lost race on stdout.
     *
     * @throws SQLException          if a statement fails
     * @throws IllegalStateException if every attempt lost the race; previously this failure was silent
     */
    public void versionfy() throws SQLException {
        Connection connection = getDataSource().getConnection();
        try {
            for (int attempt = 0; attempt < MAX_RETRY_TIMES; attempt++) {
                if (updateWithVersion(connection) == 1) {
                    return;
                }
                System.out.println("race condition happened <<<<<<<<");
            }
        } finally {
            connection.close();
        }
        throw new IllegalStateException("Optimistic update still failing after " + MAX_RETRY_TIMES + " attempts");
    }

    /**
     * Runs {@link #versionfy()} on {@link #TRY_TIMES} threads and expects no lost decrements.
     */
    @Test
    public void testVersionfy() throws SQLException, InterruptedException {
        try {
            int currentValue = consumeConcurrently(new StockOperation() {
                @Override
                public void run() throws SQLException {
                    versionfy();
                }
            });
            Assert.assertEquals(INITIAL_VALUE - TRY_TIMES, currentValue);
            System.out.println(">>>>>>>>>>>>after consume>>>>>>>>" + currentValue);
        } finally {
            closeDataSource();
        }
        System.out.println("test >>>>>>>>>>>>>");
    }

    /** One unit of work executed by each worker thread in {@link #consumeConcurrently}. */
    private interface StockOperation {
        void run() throws SQLException;
    }

    /**
     * Resets the stock row and runs {@code operation} once on each of {@link #TRY_TIMES} threads.
     * It waits for all of them, then returns the resulting {@code current_value}.
     *
     * <p>Workers count down the latch in {@code finally}, so a failing worker can no longer hang
     * the test. Failures are counted and asserted to be zero, so they cannot masquerade as a
     * race-condition result.
     */
    private int consumeConcurrently(final StockOperation operation) throws SQLException, InterruptedException {
        Connection connection = getDataSource().getConnection();
        try {
            Statement statement = connection.createStatement();
            try {
                statement.execute("update stock set current_value = " + INITIAL_VALUE + " where id=1");

                final CountDownLatch finished = new CountDownLatch(TRY_TIMES);
                final AtomicInteger failures = new AtomicInteger();
                for (int i = 0; i < TRY_TIMES; i++) {
                    new Thread(new Runnable() {
                        @Override
                        public void run() {
                            try {
                                operation.run();
                            } catch (Exception e) {
                                // Thread boundary: record and report, and let the test thread fail below.
                                failures.incrementAndGet();
                                e.printStackTrace();
                            } finally {
                                finished.countDown();
                            }
                        }
                    }).start();
                }
                finished.await();
                Assert.assertEquals("worker threads failed", 0, failures.get());

                return readCurrentValue(statement, "select * from stock where id = 1 and visible = 1");
            } finally {
                statement.close();
            }
        } finally {
            connection.close();
        }
    }

    /**
     * Executes {@code sql} and returns the {@code current_value} column of its first row.
     *
     * @throws SQLException if the query returns no rows or fails
     */
    private static int readCurrentValue(Statement statement, String sql) throws SQLException {
        ResultSet resultSet = statement.executeQuery(sql);
        try {
            if (!resultSet.next()) {
                throw new SQLException("No row returned by: " + sql);
            }
            return resultSet.getInt("current_value");
        } finally {
            resultSet.close();
        }
    }

    /** Rolls back; a rollback failure is only reported, so the original error keeps propagating. */
    private static void rollbackQuietly(Connection connection) {
        try {
            connection.rollback();
        } catch (SQLException rollbackFailure) {
            rollbackFailure.printStackTrace();
        }
    }

    /** Sleeps to simulate work; if interrupted, restores the interrupt flag instead of swallowing it. */
    private static void pause(long millis) {
        try {
            Thread.sleep(millis);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

}
